"""
coopernaut receiver policies
"""
import os
import yaml
import math
import torch
import random
import argparse
import numpy as np
from omegaconf import OmegaConf
from models.point_transformer import PointTransformer
from models.cooperative_point_transformer import CooperativePointTransformer

from policies.utils.episode_memory import EpisodeMemory
from policies.base_policy import BasePolicy
from policies.utils.coopernaut_utils import get_config
from policies.utils.lidar_processor import LidarProcessorConfig, filter_lidar_by_boundary, lidar_to_bev_v2, pad_or_truncate, pc_to_car_alignment, Sparse_Quantize, TransformMatrix_WorldCoords

class CoopernautReceiverPolicy(BasePolicy):
    """Inference-only driving policy for the receiving (ego) Coopernaut agent.

    Depending on the checkpoint path, the policy wraps either a
    ``CooperativePointTransformer`` (fuses the ego lidar with representations
    received from neighbouring agents) or a plain ``PointTransformer`` (ego
    lidar only). No learning happens here: ``learn`` and ``record_episode``
    are no-ops.
    """

    # Divisor applied to the raw ego speed before it is fed to the model.
    SPEED_SCALE = 30.0
    # Constant added to twice the bounding-box z extent when lifting lidar
    # points (named LidarRoofTopDistance in the original inline comment).
    LIDAR_ROOF_TOP_DISTANCE = 0.5

    def __init__(self, agent_id, config_file, num_checkpoint=105):
        """Build the model selected by the config and load its weights.

        Args:
            agent_id: Identifier of the agent this policy controls.
            config_file: Path handed to ``get_config`` to resolve the model
                configuration and checkpoint path.
            num_checkpoint: Checkpoint number forwarded to ``get_config``.
                Defaults to 105, the value that used to be hard-coded.
        """
        self.agent_id = agent_id
        self.episode_return = 0
        self.episode_memory = EpisodeMemory()
        self.config, self.model_path = get_config(config_file, num_checkpoint=num_checkpoint)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The model flavour is inferred from the checkpoint path: any path
        # containing the substring 'cpt' selects the cooperative model.
        self.cpt = 'cpt' in self.model_path
        model_cls = CooperativePointTransformer if self.cpt else PointTransformer
        self.model = model_cls(self.config).to(self.device)
        self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        self.model.eval()

        self.max_num_neighbors = self.config.get('max_num_neighbors')
        self.npoints = self.config.npoints
        self.ego_inverse_transform = None
        self.step_count = 0
        # Latest observation; populated by observe().
        self.observation = None

    def reset(self):
        """Clear the per-episode state (inverse transform, step count, return)."""
        self.ego_inverse_transform = None
        self.step_count = 0
        self.episode_return = 0

    def observe(self, obs, reward, terminated, truncated, info):
        """Store the latest observation and accumulate the reward.

        ``terminated``, ``truncated`` and ``info`` are accepted for interface
        compatibility but are not used.
        """
        self.observation = obs
        self.episode_return += reward

    def act(self):
        """Compute the control action for the most recent observation.

        Reads ``"LIDAR"`` (element ``[1]``, first three columns),
        ``"ego_transform"``, ``"ego_bbox"`` and ``"ego_speed"`` from the stored
        observation. ``"received_messages"`` is only read when the cooperative
        model is in use.

        Returns:
            The action dictionary produced by ``parse_action``.

        Raises:
            RuntimeError: If called before any observation was provided.
        """
        if self.observation is None:
            raise RuntimeError("act() called before observe(): no observation available")

        obs = self.observation
        raw_lidar = obs["LIDAR"][1][:, :3]

        # Pure inference: skip autograd bookkeeping for the whole forward pass.
        with torch.no_grad():
            ego_speed = torch.tensor([obs["ego_speed"] / self.SPEED_SCALE]).float().to(self.device)
            # process_messages() relies on the inverse transform stored here.
            ego_lidar, self.ego_inverse_transform = self.process_lidar(
                raw_lidar, obs["ego_transform"], obs["ego_bbox"]
            )
            if self.cpt:
                # Messages are only meaningful for (and only supported by)
                # the cooperative model, so they are processed here only.
                other_points, other_xyz = self.process_messages(obs["received_messages"])
                ego_repr = self.model.backbone(ego_lidar)
                aggr_repr = self.model.aggregate(ego_repr, other_points, other_xyz)
                control = self.model.control(aggr_repr, ego_speed)
            else:
                control = self.model(ego_lidar, ego_speed)

        self.step_count += 1
        return parse_action(control)

    @staticmethod
    def _transform_xyz(xyz, transform):
        """Apply a batched 4x4 homogeneous transform to ``(B, N, 3)`` coordinates.

        The translation is broadcast over the point dimension instead of being
        tiled explicitly.
        """
        rotation = transform[:, :3, :3]
        translation = transform[:, :3, 3].unsqueeze(1)
        rotated = torch.matmul(rotation, xyz.transpose(1, 2)).transpose(1, 2)
        return rotated + translation

    def process_messages(self, messages):
        """Merge neighbour messages into one point/feature set in the ego frame.

        Each message's ``xyz`` is moved into the ego frame using
        ``ego_inverse_transform @ message.transform``. When fewer than
        ``max_num_neighbors`` messages are present, the remainder is filled
        with the ``backbone_other`` encoding of an all-zero point cloud.

        Args:
            messages: Mapping of sender id to a message exposing
                ``transform``, ``representation`` and ``xyz``.
                NOTE(review): ``representation`` is assumed to be laid out as
                (1, N, C) with the same N as ``xyz`` - confirm against the
                sender.

        Returns:
            Tuple ``(points, xyz)`` concatenated along dimension 1, or
            ``(None, None)`` if there are no messages and nothing to pad.
        """
        point_chunks = []
        xyz_chunks = []
        for message in messages.values():
            relative = np.array([np.matmul(self.ego_inverse_transform, np.array(message.transform))])
            relative = torch.from_numpy(relative).float().to(self.device)
            points = torch.from_numpy(np.array(message.representation)).float().to(self.device)
            xyz = torch.from_numpy(np.array(message.xyz)).float().to(self.device)
            point_chunks.append(points)
            xyz_chunks.append(self._transform_xyz(xyz, relative))

        num_missing = self.max_num_neighbors - len(point_chunks)
        if num_missing > 0:
            # Every padding slot is the same zero cloud, so encode it once and
            # reuse the result. The original identity transform on the padding
            # coordinates is a no-op and is therefore omitted.
            blank = torch.zeros((1, self.npoints, 3), dtype=torch.float32, device=self.device)
            with torch.no_grad():
                pad_points, pad_xyz, _ = self.model.backbone_other(blank)
            point_chunks.extend([pad_points] * num_missing)
            xyz_chunks.extend([pad_xyz] * num_missing)

        if not point_chunks:
            return None, None
        return torch.cat(point_chunks, dim=1), torch.cat(xyz_chunks, dim=1)

    def get_episode_return(self):
        """Return the reward accumulated since the last ``reset``."""
        return self.episode_return

    def record_episode(self):
        """No-op: this policy does not record episodes."""

    def learn(self):
        """No-op: this policy is inference-only."""

    def process_lidar(self, ego_lidar, ego_transform, ego_bbox):
        """Prepare the ego point cloud and the world-to-ego matrix.

        Neither ``ego_lidar`` nor ``ego_transform`` is left modified: the point
        cloud is copied before editing and the transform's z is restored once
        the inverse matrix has been computed.

        Args:
            ego_lidar: Array of points with x, y, z in the first three columns.
            ego_transform: Ego transform exposing ``location.z``; consumed by
                ``TransformMatrix_WorldCoords``.
            ego_bbox: Ego bounding box exposing ``extent.z``.

        Returns:
            Tuple ``(lidar_tensor, ego_inverse)``: a float tensor with a leading
            batch dimension of 1 on ``self.device``, and the inverse transform
            matrix as a numpy array.
        """
        ego_z_compensation = 2 * abs(ego_bbox.extent.z) + self.LIDAR_ROOF_TOP_DISTANCE

        # The inverse matrix is computed with the ego z forced to 0; put the
        # caller's value back afterwards so the observation is not corrupted.
        original_z = ego_transform.location.z
        ego_transform.location.z = 0
        try:
            ego_inverse = np.array(TransformMatrix_WorldCoords(ego_transform).inversematrix())
        finally:
            ego_transform.location.z = original_z

        # Work on a copy: ego_lidar is typically a view into the observation,
        # and shifting it in place would accumulate across repeated calls.
        cloud = np.array(ego_lidar, copy=True)
        cloud[:, 2] = cloud[:, 2] + ego_z_compensation
        cloud = pc_to_car_alignment(cloud)
        cloud = Sparse_Quantize(cloud)
        cloud = np.unique(cloud, axis=0)  # drop duplicate rows
        cloud = pad_or_truncate(cloud, self.npoints)
        lidar_tensor = torch.from_numpy(np.array([cloud])).float().to(self.device)
        return lidar_tensor, ego_inverse

def parse_action(control):
    """Turn the model's control output into a clipped action dictionary.

    Args:
        control: Sequence of exactly three tensors ordered as
            ``(throttle, brake, steer)``.

    Returns:
        dict with the keys ``'steer'``, ``'throttle'`` and ``'brake'`` (in that
        order) holding numpy arrays; steer is clipped to [-1, 1], throttle and
        brake to [0, 1]. The dictionary is also printed to stdout.
    """
    throttle_t, brake_t, steer_t = control
    # Detach from autograd and move to host memory before converting.
    throttle, brake, steer = (
        tensor.detach().cpu().numpy() for tensor in (throttle_t, brake_t, steer_t)
    )
    action = {
        'steer': np.clip(steer, -1.0, 1.0),
        'throttle': np.clip(throttle, 0.0, 1.0),
        'brake': np.clip(brake, 0.0, 1.0),
    }
    print(action)
    return action

